{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple Train\n",
    "\n",
    "This Notebook is written to provide a self-contained overview over the data and network loading and training\n",
    "\n",
    "It is organized in three parts\n",
    "\n",
    "1. Creation of a fake dataset with random values\n",
    "2. Loading this dataset using the Tensorflows [Dataset API](https://www.tensorflow.org/api_docs/python/tf/data/Dataset)\n",
    "3. Performing one training step\n",
    "\n",
    "## load packages"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from S2parser import S2parser\n",
    "import os\n",
    "import gzip\n",
    "import shutil\n",
    "import tensorflow as tf\n",
    "\n",
    "parser=S2parser()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Create a fake random dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def write_random_sample(filepath,nobs=46, pix10=24, bands10=4, pix20=12, bands20=6, pix60=6, bands60=3):\n",
    "    \"\"\"create a fake sample with random values\"\"\"\n",
    "    x10 = (np.random.random([nobs,pix10,pix10,bands10])*1e4).astype(np.int64)\n",
    "    x20 = (np.random.random([nobs,pix20,pix20,bands20])*1e4).astype(np.int64)\n",
    "    x60 = (np.random.random([nobs,pix60,pix60,bands60])*1e4).astype(np.int64)\n",
    "    doy = (np.random.random([nobs])*365).astype(np.int64)\n",
    "    year = np.round(np.random.random([nobs])+2016.5).astype(np.int64)\n",
    "    label = (np.random.random([nobs,pix10,pix10])*17).astype(np.int64)\n",
    "\n",
    "    # create instance of parser\n",
    "    parser = S2parser()\n",
    "\n",
    "    # write .tfrecord\n",
    "    parser.write(filepath, x10, x20, x60, doy, year, label)\n",
    "\n",
    "def ziptfrecord(infile,outfile):\n",
    "    # gzip .tfrecord to .tfrecord.gz\n",
    "    with open(infile, 'rb') as f_in:\n",
    "        with gzip.open(outfile, 'wb') as f_out:\n",
    "            shutil.copyfileobj(f_in, f_out)\n",
    "\n",
    "    # remove unzipped .tfrecord\n",
    "    #os.remove(infile)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['tmp/0.tfrecord', 'tmp/1.tfrecord', 'tmp/2.tfrecord', 'tmp/3.tfrecord', 'tmp/4.tfrecord', 'tmp/5.tfrecord', 'tmp/6.tfrecord', 'tmp/7.tfrecord', 'tmp/8.tfrecord', 'tmp/9.tfrecord']\n",
      "writing tmp/0.tfrecord\n",
      "compressing tmp/0.tfrecord -> tmp/0.tfrecord.gz\n",
      "writing tmp/1.tfrecord\n",
      "compressing tmp/1.tfrecord -> tmp/1.tfrecord.gz\n",
      "writing tmp/2.tfrecord\n",
      "compressing tmp/2.tfrecord -> tmp/2.tfrecord.gz\n",
      "writing tmp/3.tfrecord\n",
      "compressing tmp/3.tfrecord -> tmp/3.tfrecord.gz\n",
      "writing tmp/4.tfrecord\n",
      "compressing tmp/4.tfrecord -> tmp/4.tfrecord.gz\n",
      "writing tmp/5.tfrecord\n",
      "compressing tmp/5.tfrecord -> tmp/5.tfrecord.gz\n",
      "writing tmp/6.tfrecord\n",
      "compressing tmp/6.tfrecord -> tmp/6.tfrecord.gz\n",
      "writing tmp/7.tfrecord\n",
      "compressing tmp/7.tfrecord -> tmp/7.tfrecord.gz\n",
      "writing tmp/8.tfrecord\n",
      "compressing tmp/8.tfrecord -> tmp/8.tfrecord.gz\n",
      "writing tmp/9.tfrecord\n",
      "compressing tmp/9.tfrecord -> tmp/9.tfrecord.gz\n"
     ]
    }
   ],
   "source": [
    "# define directory to store the fake dataset\n",
    "directory=\"tmp\"\n",
    "if not os.path.exists(directory):\n",
    "    os.makedirs(directory)\n",
    "\n",
    "filepaths=[\"{}/{}.tfrecord\".format(directory,i) for i in range(10)]\n",
    "zippedfilepaths=[\"{}/{}.tfrecord.gz\".format(directory,i) for i in range(10)]\n",
    "print(filepaths)\n",
    "\n",
    "# create the dataset\n",
    "for filepath,zippedfilepath in zip(filepaths,zippedfilepaths):\n",
    "    write_random_sample(filepath)\n",
    "    print(\"writing \"+filepath)\n",
    "    ziptfrecord(filepath,zippedfilepath)\n",
    "    print(\"compressing \"+filepath+\" -> \"+zippedfilepath)\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create a Tensorflow Dataset Iterator\n",
    "\n",
    "See https://www.tensorflow.org/programmers_guide/datasets for more details\n",
    "\n",
    "- Essentially includes the data in the processing graph as input nodes.\n",
    "- These commands set up the processing graph of the input pipeline.\n",
    "- No data is processed,yet.\n",
    "- Some errors may remain undetected until data is inferred with sess.run(.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Reset all previously loaded graphs (in case this or the next cells have been executed twice)\n",
      "creating dataset object\n",
      "applying the mapping function on all samples (will read tfrecord file and normalize the values)\n",
      "repeat forever until externally stopped\n",
      "combine samples to batches\n",
      "make iterator\n"
     ]
    }
   ],
   "source": [
    "batchsize=2\n",
    "\n",
    "print(\"Reset all previously loaded graphs (in case this or the next cells have been executed twice)\")\n",
    "tf.reset_default_graph()\n",
    "\n",
    "print(\"creating dataset object\")\n",
    "dataset = tf.data.TFRecordDataset(zippedfilepaths, compression_type=\"GZIP\")\n",
    "#dataset = tf.data.TFRecordDataset(filepaths)\n",
    "\n",
    "\n",
    "def normalize(serialized_feature):\n",
    "    \"\"\" normalize stored integer values to floats approx. [0,1] \"\"\"\n",
    "    x10, x20, x60, doy, year, labels = serialized_feature\n",
    "    x10 = tf.scalar_mul(1e-4, tf.cast(x10, tf.float32))\n",
    "    x20 = tf.scalar_mul(1e-4, tf.cast(x20, tf.float32))\n",
    "    x60 = tf.scalar_mul(1e-4, tf.cast(x60, tf.float32))\n",
    "    doy = tf.cast(doy, tf.float32) / 365\n",
    "    year = tf.cast(year, tf.float32) - 2016\n",
    "\n",
    "    return x10, x20, x60, doy, year, labels\n",
    "\n",
    "def mapping_function(serialized_feature):\n",
    "    # read data from .tfrecords\n",
    "    serialized_feature = parser.parse_example(serialized_feature)\n",
    "    return normalize(serialized_feature)\n",
    "\n",
    "print(\"applying the mapping function on all samples (will read tfrecord file and normalize the values)\")\n",
    "dataset = dataset.map(mapping_function)\n",
    "\n",
    "print(\"repeat forever until externally stopped\")\n",
    "dataset = dataset.repeat()\n",
    "\n",
    "print(\"combine samples to batches\")\n",
    "dataset = dataset.apply(tf.contrib.data.batch_and_drop_remainder(int(batchsize)))\n",
    "\n",
    "print(\"make iterator\")\n",
    "iterator = dataset.make_initializable_iterator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3a. Retrieve one sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "retrieving one sample as numpy array (just for fun)\n",
      "x10.shape: (2, 46, 24, 24, 4)\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4xLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvAOZPmwAAGLFJREFUeJzt3Xl01OW9BvDnSxLCDgE0ssoWBFQa\nlFJFVFwQlApKXbAuqFT0uqK4YG2rntJ7PW4ordgLsikipYoCxYIUF8Q9KLIICCLKEnZZQyDL9/6R\nSU+kZJ6RJDPJfZ/POZwkM09+eTPhyW8y8877mrtDRMJTLdEDEJHEUPlFAqXyiwRK5RcJlMovEiiV\nXyRQKr9IoFR++REzu8LMPjSzHDN79wjXZ5rZosj1i8wsMwHDlHKg8svhdgJ4BsBjh19hZtUBzAAw\nGUAagEkAZkQulypG5Q+QmbU1s51mdkrk46Zmts3Merr7v9x9GoBNR/jUngCSATzj7gfdfRQAA3Bu\n3AYv5UblD5C7fwPgAQCTzawWgAkAJrn7u+RTTwSwxH88J3xJ5HKpYpITPQBJDHcfa2YXA/gEgAPo\nF8On1QGw+7DLdgOoW87DkzjQmT9sYwGcBODP7n4whvw+APUOu6wegL3lPTCpeCp/oMysDooe2BsH\n4BEzaxjDpy0H0NnMrMRlnSOXSxWj8ofrWQBZ7v4bALMB/BUAzCzJzGqg6E/CamZWw8xSIp/zLoAC\nAHeaWaqZ3R65/O34Dl3Kg+n1/OExs/4ARgM42d13Ru4FLAbwMIAUFD0AWNIkd78+8rldALwAoBOA\nFQAGu/sX8Rq7lB+VXyRQutsvEiiVXyRQKr9IoFR+kUDFdYZfcs3anlI/+tPJ1XfwuSb5rZNopkZS\nHs3kFfLjFGznr1lJ3nWAZgCg+gn8wdWD62vQTGF1/js7hm8NKfsKaMZb8EzeHn4bVd+eSzO5zVJp\nJmW30Ux+LRpBcmw/MrRuvplmvl+RRjN5rXjVCgr494a86D/7/J07UbB/fwwHinP5U+o3RJvr7oma\nafHSGnqcnaMOn2T2n9qnbaWZrQf4rNTdY1vQTIOZS2kGAFq+WEgza+/rQDP7m/CS5Kbxn3/6Bz/Q\nTP7I/TSzeS6/jVqMX0UzKx5qQzPN5/BffNsy+W++Rsv5zwIAXnziSZq5o9sAmsl+ms+h2ruvJh9Q\ndvSf/caRz/BjRJTpbr+Z9TGzVWa2xsyGl+VYIhJfR11+M0sC8ByAC1E04eMqM+tUXgMTkYpVljN/\nNwBr3H2tux8CMBVA//IZlohUtLKUvxmA9SU+3hC57EfMbIiZZZlZVsEB/vejiMRHhT/V5+5j3L2r\nu3dNqlm7or+ciMSoLOXfCKDkw7zNI5eJSBVQlvJ/BiDDzFpHFnAcCGBm+QxLRCraUT/P7+75kddz\nzwWQBGC8u0dd1CGl/iEcd+H6aBFceAN/PvjphRfQTNpw/q3tPIVPzmh562qauerR2J7nH/eLU2lm\n8Mdv0MyLfXvSzBP/mkIzF8+4m2Ysmz/33K739zTT/Rp+p3Drs+1p5iCf4oEB/RbSzKsnxLbi+J09\nBvJQIz6rKP2P/P/j7kF8spSlk0lwKbHNXwDKOMnH3d8E8GZZjiEiiaG5/SKBUvlFAqXyiwRK5RcJ\nlMovEiiVXyRQKr9IoOK+V181RF/NZvY1PegxWj3OF+pYeecxNNP4QxrBrt+3pJlHh/KvBQAHHmhA\nMw/Oz6CZLuO/pZk3951EM8k5fMGPhQP+QjOX3c4nC004uznN1GrEx+MxrFGTWfs7mplbmy+aAgD5\nTfhEsBte+gfNvHT+GTTT9F3+2pfknOiV3b4jpkV8AOjMLxIslV8kUCq/SKBUfpFAqfwigVL5RQKl\n8osESuUXCVRcJ/nk7a6OTXOiT5o5rg7fR2nKCS/TzIWz7qeZ1N18K6q9LfjuOHVe4SuwAED+5Xtp\n5tcZX9BMVj++s83oh86lmQ5/200zb17ammZqf7ePZgpr1aGZltM20cyOM5rSzIQr+tJMevZ2mgEA\nq8bPj7/7x5U003wc3/ZrVseRNNP/jqFRr/efcDrXmV8kUCq/SKBUfpFAqfwigVL5RQKl8osESuUX\nCZTKLxKouE7yqVYvH7XPib4KT8oLO+hxpu/rSDN72vEJPDlN+e++VwfyiRcf5rSlGQAYt5av5jJz\n3ck0028m3x6s4NRsmjnpU34bTet3Js380I3voXX96QtoZuoTfDuzlk/wbd63j8inmfyChjQDAOkj\nUmimZps9NNMglU9e63vfPfw4H62Nen3SvkP0GMV05hcJlMovEiiVXyRQKr9IoFR+kUCp/CKBUvlF\nAqXyiwQqrpN8ktYcQoNL10fN1Hubb1nUq/ZKmpkxla9kk/TlGpr5uD9fNWfss/1oBgDe//2zNJPj\neTTz6ytvpZkp346mmWd2nEYzQ2bPpZl737iWZj69JIZtyCZvoJlRr/GtsX4+9y6a6fjQ9zQDAKtH\n8pWDUrLq0szBUdH/3wNA3uV8q63Nl0afUJb3Kl95qpjO/CKBKtOZ38zWAdgLoABAvrt3LY9BiUjF\nK4+7/ee4e2yrIYpIpaG7/SKBKmv5HcBbZrbIzIYcKWBmQ8wsy8yy8jy3jF9ORMpLWe/293D3jWZ2\nLIB5ZrbS3X/02k13HwNgDADUq9bIy/j1RKSclOnM7+4bI2+3AngdQLfyGJSIVLyjLr+Z1TazusXv\nA7gAwLLyGpiIVKyy3O1PB/C6mRUfZ4q7z4n2CXnptbDh+uirtayfzb/w+sHv0cymM/lkIR/WimZm\nnMO3mXpgwRSaAYBzh99JMw3nfkMzud35RI5u04fRTMfH+USXdvO30Aya8VVq9nc6lmZ61ufnjqHr\n+VZcJ9yymGZ6fM63TgOAkwuirzwFADNT+OpLK5/qQDNWna/Cc8z86FvDWSE9xL8ddfndfS2Anx3t\n54tIYumpPpFAqfwigVL5RQKl8osESuUXCZTKLxIolV8kUCq/SKDiu4xX3TzUP2dz1EzOP46jx7l1\nyhFfQPgjf7jxbzzzz8tpZkrWCzTT7xa+bBQAJKfy1zXlndCMZup+Gf02BICkK/ledAcz+G29LZ8v\nUfXoqbNo5sX002nmmGQ+625HL77M2Tcjfk4zqz7m+/kBgOXx86Mb/7k+0JNPXZ11Hp8pOHtR1Em0\n6LZ4Gz1GMZ35RQKl8osESuUXCZTKLxIolV8kUCq/SKBUfpFAqfwigTL3+C2o26DDsX7W2CuiZr7b\nlUaPs/drnunafRXNfLK0Hc3UWcPnQeXzlb4AAI2/LKCZDb34z+Ozi0fSzI1rf0UzG19uTTPHvsKX\n1lp/G5+cktOCf+/1VyTRTPJ+fvscszCGpce2/8AzALZczpffangl32MwJy+FZrI38//X/U7+Mur1\nf7t6LrZ8tZNv+ged+UWCpfKLBErlFwmUyi8SKJVfJFAqv0igVH6RQKn8IoGK60o+vsZxqH9u1EzO\nMD7RIeNPX9DMwKV8csqnNVvRTMupG2lmxX0taAYANg3ge7HVrRv99gGAay7hKxnNmfUyzZy/+kaa\n+fp5PhHK8w/STON3ou8xBwB72vIJPB/d8BTNnPncvTTjxlcxAoCDjfnmd9UmN6eZxp/vopl6xvc8\nHDA9K+r1c5Nz6DGK6cwvEiiVXyRQKr9IoFR+kUCp/CKBUvlFAqXyiwRK5RcJVFwn+bQ7cQ9mzHk7\nauakCZ3ocfJO45m61T6jmerratBM9mi+TE/b2nwiEACk3FGTZmr8la8wc8qE9TRz+rBbaGbHED6p\nxLfyMXfo/D3NnHI3H/OuvFo0c0336CtBAUDuY/z7qp3Fvy8ASNrEz4+973qfZhZs4ZOlti9oQjPr\nDjWOev2hwmx6jGL0OzOz8Wa21cyWlbisoZnNM7PVkbd8Wp6IVCqx3O2fCKDPYZcNBzDf3TMAzI98\nLCJVCC2/uy8AsPOwi/sDmBR5fxKAS8p5XCJSwY72Ab90dy/+42IzgPTSgmY2xMyyzCxr2w6+gquI\nxEeZH+33orW/S305lruPcfeu7t71mEZ8aWYRiY+jLf8WM2sCAJG3W8tvSCISD0db/pkABkXeHwRg\nRvkMR0TiJZan+l4B8BGAE8xsg5kNBvAYgF5mthrA+ZGPRaQKiet2XSd3TvHps6NPUri972/ocbb8\nNx/z/qUNaaZa+30002pEPs08/PpkmgGAh27iK/DsOZ6veJO2kq/W8sa0sTTTfyAfz7YufOJNQSqN\noMlHfMw3jud3IB9efDHNfNVjIs2c8tnVNAMAUzPH0czdGT1pZteVp9BMci7/f32wXvTz9coZI7F/\n23pt1yUipVP5RQKl8osESuUXCZTKLxIolV8kUCq/SKBUfpFAxXUlnw2H0vDA99Ff/ZvyHN/WqH8D\nvnLM/HE9aGZHB74Vk23aTjOTd3SnGQBYewX/XXv9ae/RzGX1F9FMl0n30Mzlf1lIM7+st5hmfnfj\nTTRjefwVnQUxnIvmnT6aZjqNv49m6q2lEQDAPVf3oplRa96imUHD+SSfeu/xQVlq9Elga3bzLeGK\n6cwvEiiVXyRQKr9IoFR+kUCp/CKBUvlFAqXyiwRK5RcJVFxX8klt08ybjrgtaqbVBP776KbR02nm\njxOvopncdD7Jp7AGzzSdH9PCKcirxb+3rrd/QTOL/tyFZnLT+Jj2tucTb7w6//77Zi6hmSUjfkYz\nG87nY36s91SaySnkSwulWGzLyGdU30wzD594Ns30/mwTzYydfBHNnPer6NvQTbtmLrZ+tUMr+YhI\n6VR+kUCp/CKBUvlFAqXyiwRK5RcJlMovEiiVXyRQcV3Jp3bqIfyi9bqomY+vbUOPc0zyHprJr8Mn\nL9VosZdmcn6oSTPdH+Qr6wDA+5vb0syTTRbQzPkHM2mmxk4+nrTZfAJPzY18S7P3ep9KMwWd+XiS\ncvjP7KGsS2nmrTP+QjN9XuKr/QBAXhq/jTo05ZtUvzSKT3Jq8elumpnVJvrPfncO//9TTGd+kUCp\n/CKBUvlFAqXyiwRK5RcJlMovEiiVXyRQKr9IoOK6kk/9Wk39tHaDo2Y292xIj3Psp3xyzthXn6eZ\noev4hJHcfnz7o9XDO9EMALTsupFmdv29Gc3szuA/s8xfrKGZWsn8e7s1/R2auXnJNTTzm4wPaWbs\n+L400/Q9PsFr4GS+fdbYPwygGQDI7hFDP+rl04gX8sV1qu3mc+6uOff9qNePG/guNi3fpZV8RKR0\ntPxmNt7MtprZshKXPWJmG81sceQfX3xMRCqVWM78EwH0OcLlI909M/LvzfIdlohUNFp+d18AIIaX\niYhIVVKWv/lvN7MlkT8L0koLmdkQM8sys6xD+fvL8OVEpDwdbfmfB9AWQCaAbABPlRZ09zHu3tXd\nu1ZPrn2UX05EyttRld/dt7h7gbsXAhgLoFv5DktEKtpRld/MmpT48FIAy0rLikjlRGcVmNkrAHoC\naGxmGwA8DKCnmWUCcADrANwc81ckv27qreMTJr65vA7NXPI/fKWWJ+4dQzNPNbuMZjIe+4pmAGB/\njxNoZjff+QmtZx2kmaW5GTTTbEEezTQc+0+aqTWtPs3M+rYnzdz0wmyamXkRXxJo8n/9kmb23cVX\nzQGAzmn8se4D+Sn8QOdtoJH9c/gqVlOW/zzq9TsORN/OqyRafnc/0qZ342L+CiJSKWmGn0igVH6R\nQKn8IoFS+UUCpfKLBErlFwmUyi8SqLhu15XbOAlf39ggaqbd1Bx6nLSOuTRzyy/5tkUPruSruRw6\ntxHNPDX0VZoBgBG38ok3TT+I6VBUvcwdNLOlI5+ccvdZA2mmyUvf0MyyjU1p5p9bT6SZmsl8YtLK\nm/lEsW7HZtMMAOQW8IrM7TiXZnpX41ua3duWr0A05o4Lol6/fVMBPUYxnflFAqXyiwRK5RcJlMov\nEiiVXyRQKr9IoFR+kUCp/CKBiut2XXXaH+eZz10XNbN+A59U8/q5z9HMBwfa8ePc2otm1t5IIzhu\ndnUeAjDtiSdp5sx5Q2mm+ewkminkEWw6l//sj2vFJwsVvnIszTT6ZBvNnPPaFzTzQx5fBPb7A3zL\ntxU7+JgBIL3OPpo5OKIJzaQu/pZmCnb+QDMnZUXfiWvK1fOw5aud2q5LREqn8osESuUXCZTKLxIo\nlV8kUCq/SKBUfpFAqfwigVL5RQIV12W8qmUnIXVE9GW86nbjs+U69uG/s57e3Jpm9j/A92tr/7sa\nNPP89FE0AwC9xt1PM6ee/zXN5DzAZ+bd9jFfD6xNMt+Hru+su2mmSS4fT9/XP6GZv646k2aGdfwX\nzSzbw5cMy2i4nWYAYNM+vg9hSgyTZHf1ak8zZ9zPb6NrG34Y9fq5yXxGYjGd+UUCpfKLBErlFwmU\nyi8SKJVfJFAqv0igVH6RQKn8IoGK6ySfvHRH9t2HomauzVhIj/NRbirNrBrN933LvGMxzbSZyCeD\nXPokn7wDAO/d/wTN9P3yBpopuCyG/QNvaUUzt4zmewxWO8RXhNrVlp9D3tvJJ7l0b8aXujq75lo+\nnsa1aObvj/SmGQAoqM6//zqb+WSp5Bw+WWx5rzSauW/n6VGv31DIJ0EVoz81M2thZu+Y2VdmttzM\n7opc3tDM5pnZ6shbPnIRqTRiudufD2CYu3cCcBqA28ysE4DhAOa7ewaA+ZGPRaSKoOV392x3/zzy\n/l4AKwA0A9AfwKRIbBKASypqkCJS/n7SA35m1gpAFwCfAEh39+JNzjcDSC/lc4aYWZaZZeXvySnD\nUEWkPMVcfjOrA+A1AEPdfU/J67xo8f8jvrbJ3ce4e1d375pcjz8QIyLxEVP5zSwFRcV/2d2nRy7e\nYmZNItc3AbC1YoYoIhUhlkf7DcA4ACvc/ekSV80EMCjy/iAAM8p/eCJSUWJ5nv8MANcCWGpmxU+M\n/xbAYwCmmdlgAN8BuKJihigiFYGW390XAihtpsN5P+mLbauGxmOj/91/9nMr6XFWHuJ7o+U24pMz\nUqvl08yce8+mmVHPj6YZADhr/H00k5TLj5PUh+/plvw4XxFpwoALaSb14vKZBLrkbT7Jp+1oPoGn\n/9V8QtW+zvxGPGfYcpoBgLeXdKSZ3vfzY13XgK/SM+GH6BN4AGDusz2iXl8w4yN6jGKa3isSKJVf\nJFAqv0igVH6RQKn8IoFS+UUCpfKLBErlFwlUXFfyOdQQ+P7KwqiZ9KQD9Dg3/+9lNPPl/XzizXnX\nDqaZ/c1TaOadvZ1oBgDqreX7OjX+lK8c9MptL9LMPU/3opm3l3egmfaDo28PBQC7rz6NZtKW7qKZ\nXRNr00zSLBopfUpaCWv38tWQAOD4GCatL2rfkmYuqLuUZoY05BN0pnTpHvX6/Hn0EP+mM79IoFR+\nkUCp/CKBUvlFAqXyiwRK5RcJlMovEiiVXyRQcZ3kYwcNqWujb7X1YEY/epxmc7bRzM8Kb6WZ9Jz9\nNLOnb/TtxQDg1Zd60gwApO3iKwd5df4jmbyHr4qzcP7JNJPSln//161aTzPNkpfQzOMXDaCZDzpP\np5l9J/FVevqv5CvK7Z/YlGYA4OQ/8Mk5G24+nmYercm3YStMTaKZpIvI+ZrPI/s3nflFAqXyiwRK\n5RcJlMovEiiVXyRQKr9IoFR+kUCp/CKBsqLdteP0xcy2oWhfv2KNAfClayqfqjhujTl+Ejnu4939\nmFiCcS3/f3xxsyx375qwARylqjhujTl+qsq4dbdfJFAqv0igEl3+MQn++kerKo5bY46fKjHuhP7N\nLyKJk+gzv4gkiMovEqiEld/M+pjZKjNbY2bDEzWOn8LM1pnZUjNbbGZZiR5PacxsvJltNbNlJS5r\naGbzzGx15G1aIsd4uFLG/IiZbYzc3ovN7KJEjvFwZtbCzN4xs6/MbLmZ3RW5vFLf1sUSUn4zSwLw\nHIALAXQCcJWZxbbnVeKd4+6Zlfx53IkA+hx22XAA8909A8D8yMeVyUT855gBYGTk9s509zfjPCYm\nH8Awd+8E4DQAt0X+H1f22xpA4s783QCscfe17n4IwFQA/RM0lv933H0BgJ2HXdwfwKTI+5MAXBLX\nQRGljLlSc/dsd/888v5eACsANEMlv62LJar8zQCUXBxuQ+Syys4BvGVmi8xsSKIH8xOlu3t25P3N\nANITOZif4HYzWxL5s6BS3n0GADNrBaALgE9QRW5rPeD30/Rw91NQ9OfKbWZ2VqIHdDS86PndqvAc\n7/MA2gLIBJAN4KnEDufIzKwOgNcADHX3PSWvq8y3daLKvxFAixIfN49cVqm5+8bI260AXkfRny9V\nxRYzawIAkbdbEzweyt23uHuBuxcCGItKeHubWQqKiv+yuxcvP1wlbutElf8zABlm1trMqgMYCGBm\ngsYSEzOrbWZ1i98HcAGAZdE/q1KZCWBQ5P1BAGLYeT6xigsUcSkq2e1tZgZgHIAV7v50iauqxG2d\nsBl+kadtngGQBGC8u/8pIQOJkZm1QdHZHija72BKZR2zmb0CoCeKXlq6BcDDAN4AMA1ASxS9rPoK\nd680D7CVMuaeKLrL7wDWAbi5xN/SCWdmPQC8D2ApgMLIxb9F0d/9lfa2LqbpvSKB0gN+IoFS+UUC\npfKLBErlFwmUyi8SKJVfJFAqv0ig/g8pQnJiQchYYQAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<matplotlib.figure.Figure at 0x7f5b08e5f850>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "\n",
    "with tf.Session() as sess:\n",
    "    sess.run(iterator.initializer)\n",
    "    print(\"retrieving one sample as numpy array (just for fun)\")\n",
    "    x10, x20, x60, doy, year, labels = sess.run(iterator.get_next())\n",
    "    print(\"x10.shape: \" + str(x10.shape))\n",
    "    plt.imshow(x10[0,0,:,:,0])\n",
    "    plt.title(\"x10\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 3b Perform one training step"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loading network graph definition\n",
      "initializing variables, tables and the data iterator\n",
      "getting one string handle from the iterator that can be fed to the network\n",
      "making some processing nodes accessible to python\n",
      "performing one training step\n"
     ]
    }
   ],
   "source": [
    "# define the network to be loaded\n",
    "# if not yet created, make one with python modelzoo/seqencoder.py script, as described in the readme.md\n",
    "graph=\"tmp/convgru128/graph.meta\"\n",
    "\n",
    "with tf.Session() as sess:\n",
    "\n",
    "    print(\"loading network graph definition\")\n",
    "    tf.train.import_meta_graph(graph)\n",
    "\n",
    "    print(\"initializing variables, tables and the data iterator\")\n",
    "    sess.run([tf.global_variables_initializer(), tf.local_variables_initializer(), tf.tables_initializer(),iterator.initializer])\n",
    "\n",
    "    print(\"getting one string handle from the iterator that can be fed to the network\")\n",
    "    iterator_handle = sess.run(iterator.string_handle())\n",
    "\n",
    "    print(\"making some processing nodes accessible to python\")\n",
    "    def get_operation(name):\n",
    "        return tf.get_default_graph().get_operation_by_name(name).outputs[0]\n",
    "\n",
    "    iterator_handle_op = get_operation(\"data_iterator_handle\")\n",
    "    is_train_op = get_operation(\"is_train\")\n",
    "    train_op = get_operation(\"train_op\")\n",
    "\n",
    "    print(\"performing one training step\")\n",
    "    sess.run(train_op,feed_dict={iterator_handle_op: iterator_handle, is_train_op: True})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
